{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:04:41.203055400Z",
     "start_time": "2023-05-09T05:04:41.180057800Z"
    }
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append('..')\n",
    "import numpy as np\n",
    "from common.time_layers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "class SimpleRnnlm:\n",
    "    \"\"\"Language model stacking TimeEmbedding, TimeRNN and TimeAffine layers.\n",
    "\n",
    "    The loss is produced by a TimeSoftmaxWithLoss layer. All layer parameters\n",
    "    and gradients are gathered into the flat lists ``params`` and ``grads``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, wordvec_size, hidden_size):\n",
    "        \"\"\"Create the initial weights and assemble the layers.\n",
    "\n",
    "        Args:\n",
    "            vocab_size: Rows of the embedding matrix and columns of the\n",
    "                affine (output) weight.\n",
    "            wordvec_size: Columns of the embedding matrix and rows of the\n",
    "                RNN input weight.\n",
    "            hidden_size: Length of the RNN bias and both dimensions of the\n",
    "                RNN hidden-to-hidden weight.\n",
    "        \"\"\"\n",
    "        randn = np.random.randn\n",
    "\n",
    "        # Initial weights, all cast with astype('f'). The random draws are\n",
    "        # kept in the order embedding -> Wx -> Wh -> affine.\n",
    "        embed_W = (randn(vocab_size, wordvec_size) / 100).astype('f')\n",
    "        rnn_Wx = (randn(wordvec_size, hidden_size) / np.sqrt(wordvec_size)).astype('f')\n",
    "        rnn_Wh = (randn(hidden_size, hidden_size) / np.sqrt(hidden_size)).astype('f')\n",
    "        rnn_b = np.zeros(hidden_size).astype('f')\n",
    "        affine_W = (randn(hidden_size, vocab_size) / np.sqrt(hidden_size)).astype('f')\n",
    "        affine_b = np.zeros(vocab_size).astype('f')\n",
    "\n",
    "        # Layers, listed in forward order.\n",
    "        self.layers = [\n",
    "            TimeEmbedding(embed_W),\n",
    "            TimeRNN(rnn_Wx, rnn_Wh, rnn_b, stateful=True),\n",
    "            TimeAffine(affine_W, affine_b),\n",
    "        ]\n",
    "        # Keep a direct handle on the RNN layer so reset_state() can reach it.\n",
    "        _, self.rnn_layer, _ = self.layers\n",
    "        self.loss_layer = TimeSoftmaxWithLoss()\n",
    "\n",
    "        # Flatten every layer's parameters and gradients into single lists.\n",
    "        self.params = [p for layer in self.layers for p in layer.params]\n",
    "        self.grads = [g for layer in self.layers for g in layer.grads]\n",
    "\n",
    "    def forward(self, xs, ts):\n",
    "        \"\"\"Feed xs through every layer in order and return the loss against ts.\"\"\"\n",
    "        out = xs\n",
    "        for layer in self.layers:\n",
    "            out = layer.forward(out)\n",
    "        return self.loss_layer.forward(out, ts)\n",
    "\n",
    "    def backward(self, dout=1):\n",
    "        \"\"\"Propagate dout back through the loss layer, then the layers in reverse.\n",
    "\n",
    "        Returns whatever the first layer's backward() returns.\n",
    "        \"\"\"\n",
    "        grad = self.loss_layer.backward(dout)\n",
    "        for layer in self.layers[::-1]:\n",
    "            grad = layer.backward(grad)\n",
    "        return grad\n",
    "\n",
    "    def reset_state(self):\n",
    "        \"\"\"Delegate to the RNN layer's reset_state().\"\"\"\n",
    "        self.rnn_layer.reset_state()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:35:13.617267900Z",
     "start_time": "2023-05-09T05:35:13.611270400Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append('..')\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from common.optimizer import SGD\n",
    "from dataset import ptb"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:33:31.223506600Z",
     "start_time": "2023-05-09T05:33:31.215511200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "batch_size = 10\n",
    "wordvec_size = 100\n",
    "hidden_size = 100 # number of elements in the RNN hidden-state vector\n",
    "time_size = 5 # time span of Truncated BPTT\n",
    "lr = 0.1\n",
    "max_epoch = 100"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:33:56.553941800Z",
     "start_time": "2023-05-09T05:33:56.538943100Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# 读入训练数据（缩小了数据集）\n",
    "corpus, word_to_id, id_to_word = ptb.load_data('train')\n",
    "corpus_size = 1000\n",
    "corpus = corpus[:corpus_size]\n",
    "vocab_size = int(max(corpus) + 1)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:34:08.121676Z",
     "start_time": "2023-05-09T05:34:08.110677300Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "corpus size: 1000, vocabulary size: 418\n"
     ]
    }
   ],
   "source": [
    "xs = corpus[:-1] # inputs\n",
    "ts = corpus[1:] # outputs (supervised labels): the inputs shifted by one position\n",
    "data_size = len(xs)\n",
    "print('corpus size: %d, vocabulary size: %d' % (corpus_size, vocab_size))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:34:24.489741400Z",
     "start_time": "2023-05-09T05:34:24.476742Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "# 学习用的参数\n",
    "max_iters = data_size // (batch_size * time_size)\n",
    "time_idx = 0\n",
    "total_loss = 0\n",
    "loss_count = 0\n",
    "ppl_list = []"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:34:41.261999Z",
     "start_time": "2023-05-09T05:34:41.251Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "# 生成模型\n",
    "model = SimpleRnnlm(vocab_size, wordvec_size, hidden_size)\n",
    "optimizer = SGD(lr)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:35:17.045609Z",
     "start_time": "2023-05-09T05:35:17.021623700Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [],
   "source": [
    "from common.trainer import RnnlmTrainer"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:40:11.480274900Z",
     "start_time": "2023-05-09T05:40:11.467274200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [],
   "source": [
    "# Wrap the model and optimizer in the trainer that runs the training loop\n",
    "trainer = RnnlmTrainer(model, optimizer)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:40:20.046789200Z",
     "start_time": "2023-05-09T05:40:20.040790200Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| epoch 1 |  iter 1 / 19 | time 0[s] | perplexity 417.61\n",
      "| epoch 2 |  iter 1 / 19 | time 0[s] | perplexity 402.39\n",
      "| epoch 3 |  iter 1 / 19 | time 0[s] | perplexity 310.19\n",
      "| epoch 4 |  iter 1 / 19 | time 0[s] | perplexity 234.39\n",
      "| epoch 5 |  iter 1 / 19 | time 0[s] | perplexity 214.16\n",
      "| epoch 6 |  iter 1 / 19 | time 0[s] | perplexity 212.12\n",
      "| epoch 7 |  iter 1 / 19 | time 0[s] | perplexity 202.63\n",
      "| epoch 8 |  iter 1 / 19 | time 0[s] | perplexity 201.66\n",
      "| epoch 9 |  iter 1 / 19 | time 0[s] | perplexity 195.12\n",
      "| epoch 10 |  iter 1 / 19 | time 0[s] | perplexity 190.62\n",
      "| epoch 11 |  iter 1 / 19 | time 0[s] | perplexity 193.12\n",
      "| epoch 12 |  iter 1 / 19 | time 0[s] | perplexity 188.80\n",
      "| epoch 13 |  iter 1 / 19 | time 0[s] | perplexity 191.83\n",
      "| epoch 14 |  iter 1 / 19 | time 0[s] | perplexity 186.91\n",
      "| epoch 15 |  iter 1 / 19 | time 0[s] | perplexity 185.39\n",
      "| epoch 16 |  iter 1 / 19 | time 0[s] | perplexity 189.59\n",
      "| epoch 17 |  iter 1 / 19 | time 0[s] | perplexity 187.64\n",
      "| epoch 18 |  iter 1 / 19 | time 0[s] | perplexity 182.57\n",
      "| epoch 19 |  iter 1 / 19 | time 0[s] | perplexity 181.99\n",
      "| epoch 20 |  iter 1 / 19 | time 0[s] | perplexity 180.28\n",
      "| epoch 21 |  iter 1 / 19 | time 0[s] | perplexity 179.21\n",
      "| epoch 22 |  iter 1 / 19 | time 0[s] | perplexity 176.85\n",
      "| epoch 23 |  iter 1 / 19 | time 0[s] | perplexity 178.27\n",
      "| epoch 24 |  iter 1 / 19 | time 0[s] | perplexity 174.58\n",
      "| epoch 25 |  iter 1 / 19 | time 0[s] | perplexity 173.09\n",
      "| epoch 26 |  iter 1 / 19 | time 0[s] | perplexity 172.69\n",
      "| epoch 27 |  iter 1 / 19 | time 0[s] | perplexity 168.97\n",
      "| epoch 28 |  iter 1 / 19 | time 0[s] | perplexity 169.78\n",
      "| epoch 29 |  iter 1 / 19 | time 0[s] | perplexity 163.76\n",
      "| epoch 30 |  iter 1 / 19 | time 0[s] | perplexity 160.76\n",
      "| epoch 31 |  iter 1 / 19 | time 0[s] | perplexity 158.52\n",
      "| epoch 32 |  iter 1 / 19 | time 0[s] | perplexity 152.35\n",
      "| epoch 33 |  iter 1 / 19 | time 0[s] | perplexity 153.21\n",
      "| epoch 34 |  iter 1 / 19 | time 1[s] | perplexity 153.67\n",
      "| epoch 35 |  iter 1 / 19 | time 1[s] | perplexity 147.45\n",
      "| epoch 36 |  iter 1 / 19 | time 1[s] | perplexity 142.92\n",
      "| epoch 37 |  iter 1 / 19 | time 1[s] | perplexity 141.94\n",
      "| epoch 38 |  iter 1 / 19 | time 1[s] | perplexity 133.85\n",
      "| epoch 39 |  iter 1 / 19 | time 1[s] | perplexity 131.98\n",
      "| epoch 40 |  iter 1 / 19 | time 1[s] | perplexity 126.18\n",
      "| epoch 41 |  iter 1 / 19 | time 1[s] | perplexity 120.06\n",
      "| epoch 42 |  iter 1 / 19 | time 1[s] | perplexity 119.86\n",
      "| epoch 43 |  iter 1 / 19 | time 1[s] | perplexity 113.61\n",
      "| epoch 44 |  iter 1 / 19 | time 1[s] | perplexity 111.73\n",
      "| epoch 45 |  iter 1 / 19 | time 1[s] | perplexity 103.56\n",
      "| epoch 46 |  iter 1 / 19 | time 1[s] | perplexity 100.25\n",
      "| epoch 47 |  iter 1 / 19 | time 1[s] | perplexity 98.67\n",
      "| epoch 48 |  iter 1 / 19 | time 1[s] | perplexity 93.54\n",
      "| epoch 49 |  iter 1 / 19 | time 1[s] | perplexity 91.70\n",
      "| epoch 50 |  iter 1 / 19 | time 1[s] | perplexity 84.07\n",
      "| epoch 51 |  iter 1 / 19 | time 1[s] | perplexity 81.20\n",
      "| epoch 52 |  iter 1 / 19 | time 1[s] | perplexity 77.45\n",
      "| epoch 53 |  iter 1 / 19 | time 1[s] | perplexity 74.60\n",
      "| epoch 54 |  iter 1 / 19 | time 1[s] | perplexity 71.16\n",
      "| epoch 55 |  iter 1 / 19 | time 1[s] | perplexity 66.86\n",
      "| epoch 56 |  iter 1 / 19 | time 1[s] | perplexity 64.37\n",
      "| epoch 57 |  iter 1 / 19 | time 1[s] | perplexity 63.35\n",
      "| epoch 58 |  iter 1 / 19 | time 1[s] | perplexity 57.80\n",
      "| epoch 59 |  iter 1 / 19 | time 1[s] | perplexity 57.00\n",
      "| epoch 60 |  iter 1 / 19 | time 1[s] | perplexity 50.98\n",
      "| epoch 61 |  iter 1 / 19 | time 1[s] | perplexity 49.12\n",
      "| epoch 62 |  iter 1 / 19 | time 1[s] | perplexity 48.23\n",
      "| epoch 63 |  iter 1 / 19 | time 1[s] | perplexity 46.00\n",
      "| epoch 64 |  iter 1 / 19 | time 1[s] | perplexity 42.25\n",
      "| epoch 65 |  iter 1 / 19 | time 1[s] | perplexity 38.82\n",
      "| epoch 66 |  iter 1 / 19 | time 1[s] | perplexity 37.70\n",
      "| epoch 67 |  iter 1 / 19 | time 2[s] | perplexity 35.61\n",
      "| epoch 68 |  iter 1 / 19 | time 2[s] | perplexity 34.02\n",
      "| epoch 69 |  iter 1 / 19 | time 2[s] | perplexity 32.15\n",
      "| epoch 70 |  iter 1 / 19 | time 2[s] | perplexity 29.73\n",
      "| epoch 71 |  iter 1 / 19 | time 2[s] | perplexity 28.59\n",
      "| epoch 72 |  iter 1 / 19 | time 2[s] | perplexity 25.94\n",
      "| epoch 73 |  iter 1 / 19 | time 2[s] | perplexity 25.10\n",
      "| epoch 74 |  iter 1 / 19 | time 2[s] | perplexity 23.91\n",
      "| epoch 75 |  iter 1 / 19 | time 2[s] | perplexity 23.30\n",
      "| epoch 76 |  iter 1 / 19 | time 2[s] | perplexity 21.92\n",
      "| epoch 77 |  iter 1 / 19 | time 2[s] | perplexity 21.00\n",
      "| epoch 78 |  iter 1 / 19 | time 2[s] | perplexity 18.99\n",
      "| epoch 79 |  iter 1 / 19 | time 2[s] | perplexity 18.50\n",
      "| epoch 80 |  iter 1 / 19 | time 2[s] | perplexity 17.98\n",
      "| epoch 81 |  iter 1 / 19 | time 2[s] | perplexity 15.97\n",
      "| epoch 82 |  iter 1 / 19 | time 2[s] | perplexity 15.83\n",
      "| epoch 83 |  iter 1 / 19 | time 2[s] | perplexity 14.93\n",
      "| epoch 84 |  iter 1 / 19 | time 2[s] | perplexity 14.87\n",
      "| epoch 85 |  iter 1 / 19 | time 2[s] | perplexity 13.32\n",
      "| epoch 86 |  iter 1 / 19 | time 2[s] | perplexity 12.85\n",
      "| epoch 87 |  iter 1 / 19 | time 2[s] | perplexity 12.07\n",
      "| epoch 88 |  iter 1 / 19 | time 2[s] | perplexity 11.13\n",
      "| epoch 89 |  iter 1 / 19 | time 2[s] | perplexity 10.89\n",
      "| epoch 90 |  iter 1 / 19 | time 2[s] | perplexity 10.28\n",
      "| epoch 91 |  iter 1 / 19 | time 2[s] | perplexity 9.31\n",
      "| epoch 92 |  iter 1 / 19 | time 2[s] | perplexity 9.28\n",
      "| epoch 93 |  iter 1 / 19 | time 2[s] | perplexity 9.29\n",
      "| epoch 94 |  iter 1 / 19 | time 2[s] | perplexity 8.20\n",
      "| epoch 95 |  iter 1 / 19 | time 2[s] | perplexity 8.07\n",
      "| epoch 96 |  iter 1 / 19 | time 2[s] | perplexity 7.49\n",
      "| epoch 97 |  iter 1 / 19 | time 2[s] | perplexity 7.21\n",
      "| epoch 98 |  iter 1 / 19 | time 2[s] | perplexity 6.66\n",
      "| epoch 99 |  iter 1 / 19 | time 2[s] | perplexity 6.39\n",
      "| epoch 100 |  iter 1 / 19 | time 3[s] | perplexity 6.37\n"
     ]
    }
   ],
   "source": [
    "# Train for max_epoch epochs; the recorded output shows one perplexity line per epoch\n",
    "trainer.fit(xs, ts, max_epoch, batch_size, time_size)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2023-05-09T05:40:43.104900300Z",
     "start_time": "2023-05-09T05:40:40.044543300Z"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
